Config: Cleanup and organize functions
Remove access of private attributes and use safer functions. Also move generalized functions into utils files. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
0903f852db
commit
a09dd802c2
3 changed files with 47 additions and 15 deletions
|
|
@ -1,20 +1,10 @@
|
|||
"""Argparser for overriding config values"""
|
||||
|
||||
import argparse
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.config_models import TabbyConfigModel
|
||||
|
||||
|
||||
def is_list_type(type_hint):
|
||||
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is list:
|
||||
return True
|
||||
if hasattr(type_hint, "__args__"):
|
||||
# Recursively check for lists inside type arguments
|
||||
return any(is_list_type(arg) for arg in type_hint.__args__)
|
||||
return False
|
||||
from common.utils import is_list_type
|
||||
|
||||
|
||||
def add_field_to_group(group, field_name, field_type, field) -> None:
|
||||
|
|
@ -26,6 +16,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None:
|
|||
"help": field.description if field.description else "No description available",
|
||||
}
|
||||
|
||||
# If the inner type contains a list, specify argparse as such
|
||||
if is_list_type(field_type):
|
||||
kwargs["nargs"] = "+"
|
||||
|
||||
|
|
@ -63,7 +54,7 @@ def init_argparser() -> argparse.ArgumentParser:
|
|||
|
||||
def convert_args_to_dict(
|
||||
args: argparse.Namespace, parser: argparse.ArgumentParser
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
) -> dict:
|
||||
"""Broad conversion of surface level arg groups to dictionaries"""
|
||||
|
||||
arg_groups = {}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
|
|||
|
||||
|
||||
class ConfigOverrideConfig(BaseModel):
|
||||
"""Model for overriding a provided config file."""
|
||||
|
||||
# TODO: convert this to a pathlib.path?
|
||||
config: Optional[str] = Field(
|
||||
None, description=("Path to an overriding config.yml file")
|
||||
|
|
@ -13,6 +15,9 @@ class ConfigOverrideConfig(BaseModel):
|
|||
|
||||
|
||||
class UtilityActions(BaseModel):
|
||||
"""Model used for arg actions."""
|
||||
|
||||
# YAML export options
|
||||
export_config: Optional[str] = Field(
|
||||
None, description="generate a template config file"
|
||||
)
|
||||
|
|
@ -20,6 +25,7 @@ class UtilityActions(BaseModel):
|
|||
"config_sample.yml", description="path to export configuration file to"
|
||||
)
|
||||
|
||||
# OpenAPI JSON export options
|
||||
export_openapi: Optional[bool] = Field(
|
||||
False, description="export openapi schema files"
|
||||
)
|
||||
|
|
@ -29,6 +35,8 @@ class UtilityActions(BaseModel):
|
|||
|
||||
|
||||
class NetworkConfig(BaseModel):
|
||||
"""Model for network configuration."""
|
||||
|
||||
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
|
||||
port: Optional[int] = Field(5000, description=("The port to host on"))
|
||||
disable_auth: Optional[bool] = Field(
|
||||
|
|
@ -47,6 +55,8 @@ class NetworkConfig(BaseModel):
|
|||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
"""Model for logging configuration."""
|
||||
|
||||
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
|
||||
log_generation_params: Optional[bool] = Field(
|
||||
False, description=("Enable generation parameter logging")
|
||||
|
|
@ -55,6 +65,8 @@ class LoggingConfig(BaseModel):
|
|||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""Model for LLM configuration."""
|
||||
|
||||
# TODO: convert this to a pathlib.path?
|
||||
model_dir: str = Field(
|
||||
"models",
|
||||
|
|
@ -201,6 +213,8 @@ class ModelConfig(BaseModel):
|
|||
|
||||
|
||||
class DraftModelConfig(BaseModel):
|
||||
"""Model for draft LLM model configuration."""
|
||||
|
||||
# TODO: convert this to a pathlib.path?
|
||||
draft_model_dir: Optional[str] = Field(
|
||||
"models",
|
||||
|
|
@ -239,6 +253,8 @@ class DraftModelConfig(BaseModel):
|
|||
|
||||
|
||||
class LoraInstanceModel(BaseModel):
|
||||
"""Model representing an instance of a Lora."""
|
||||
|
||||
name: str = Field(..., description=("Name of the LoRA model"))
|
||||
scaling: float = Field(
|
||||
1.0,
|
||||
|
|
@ -248,6 +264,8 @@ class LoraInstanceModel(BaseModel):
|
|||
|
||||
|
||||
class LoraConfig(BaseModel):
|
||||
"""Model for lora configuration."""
|
||||
|
||||
# TODO: convert this to a pathlib.path?
|
||||
lora_dir: Optional[str] = Field(
|
||||
"loras", description=("Directory to look for LoRAs (default: 'loras')")
|
||||
|
|
@ -262,12 +280,16 @@ class LoraConfig(BaseModel):
|
|||
|
||||
|
||||
class SamplingConfig(BaseModel):
|
||||
"""Model for sampling (overrides) config."""
|
||||
|
||||
override_preset: Optional[str] = Field(
|
||||
None, description=("Select a sampler override preset")
|
||||
)
|
||||
|
||||
|
||||
class DeveloperConfig(BaseModel):
|
||||
"""Model for developer settings configuration."""
|
||||
|
||||
unsafe_launch: Optional[bool] = Field(
|
||||
False, description=("Skip Exllamav2 version check")
|
||||
)
|
||||
|
|
@ -290,6 +312,8 @@ class DeveloperConfig(BaseModel):
|
|||
|
||||
|
||||
class EmbeddingsConfig(BaseModel):
|
||||
"""Model for embeddings configuration."""
|
||||
|
||||
# TODO: convert this to a pathlib.path?
|
||||
embedding_model_dir: Optional[str] = Field(
|
||||
"models",
|
||||
|
|
@ -310,6 +334,8 @@ class EmbeddingsConfig(BaseModel):
|
|||
|
||||
|
||||
class TabbyConfigModel(BaseModel):
|
||||
"""Base model for a TabbyConfig."""
|
||||
|
||||
config: ConfigOverrideConfig = Field(
|
||||
default_factory=ConfigOverrideConfig.model_construct
|
||||
)
|
||||
|
|
@ -331,6 +357,8 @@ class TabbyConfigModel(BaseModel):
|
|||
|
||||
|
||||
def generate_config_file(filename="config_sample.yml", indentation=2):
|
||||
"""Creates a config.yml file from Pydantic models."""
|
||||
|
||||
schema = TabbyConfigModel.model_json_schema()
|
||||
|
||||
def dump_def(id: str, indent=2):
|
||||
|
|
@ -356,6 +384,3 @@ def generate_config_file(filename="config_sample.yml", indentation=2):
|
|||
|
||||
with open(filename, "w") as f:
|
||||
f.write(yaml)
|
||||
|
||||
|
||||
# generate_config_file("test.yml")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""Common utility functions"""
|
||||
|
||||
from typing import get_args, get_origin
|
||||
|
||||
|
||||
def unwrap(wrapped, default=None):
|
||||
"""Unwrap function for Optionals."""
|
||||
|
|
@ -43,3 +45,17 @@ def flat_map(input_list):
|
|||
"""Flattens a list of lists into a single list."""
|
||||
|
||||
return [item for sublist in input_list for item in sublist]
|
||||
|
||||
|
||||
def is_list_type(type_hint):
|
||||
"""Checks if a type contains a list."""
|
||||
|
||||
if get_origin(type_hint) is list:
|
||||
return True
|
||||
|
||||
# Recursively check for lists inside type arguments
|
||||
type_args = get_args(type_hint)
|
||||
if type_args:
|
||||
return any(is_list_type(arg) for arg in type_args)
|
||||
|
||||
return False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue