Config: Cleanup and organize functions
Remove access of private attributes and use safer functions. Also move generalized functions into utils files. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
0903f852db
commit
a09dd802c2
3 changed files with 47 additions and 15 deletions
|
|
@ -1,20 +1,10 @@
|
||||||
"""Argparser for overriding config values"""
|
"""Argparser for overriding config values"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from common.config_models import TabbyConfigModel
|
from common.config_models import TabbyConfigModel
|
||||||
|
from common.utils import is_list_type
|
||||||
|
|
||||||
def is_list_type(type_hint):
|
|
||||||
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is list:
|
|
||||||
return True
|
|
||||||
if hasattr(type_hint, "__args__"):
|
|
||||||
# Recursively check for lists inside type arguments
|
|
||||||
return any(is_list_type(arg) for arg in type_hint.__args__)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def add_field_to_group(group, field_name, field_type, field) -> None:
|
def add_field_to_group(group, field_name, field_type, field) -> None:
|
||||||
|
|
@ -26,6 +16,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None:
|
||||||
"help": field.description if field.description else "No description available",
|
"help": field.description if field.description else "No description available",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# If the inner type contains a list, specify argparse as such
|
||||||
if is_list_type(field_type):
|
if is_list_type(field_type):
|
||||||
kwargs["nargs"] = "+"
|
kwargs["nargs"] = "+"
|
||||||
|
|
||||||
|
|
@ -63,7 +54,7 @@ def init_argparser() -> argparse.ArgumentParser:
|
||||||
|
|
||||||
def convert_args_to_dict(
|
def convert_args_to_dict(
|
||||||
args: argparse.Namespace, parser: argparse.ArgumentParser
|
args: argparse.Namespace, parser: argparse.ArgumentParser
|
||||||
) -> dict[str, dict[str, Any]]:
|
) -> dict:
|
||||||
"""Broad conversion of surface level arg groups to dictionaries"""
|
"""Broad conversion of surface level arg groups to dictionaries"""
|
||||||
|
|
||||||
arg_groups = {}
|
arg_groups = {}
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
|
||||||
|
|
||||||
|
|
||||||
class ConfigOverrideConfig(BaseModel):
|
class ConfigOverrideConfig(BaseModel):
|
||||||
|
"""Model for overriding a provided config file."""
|
||||||
|
|
||||||
# TODO: convert this to a pathlib.path?
|
# TODO: convert this to a pathlib.path?
|
||||||
config: Optional[str] = Field(
|
config: Optional[str] = Field(
|
||||||
None, description=("Path to an overriding config.yml file")
|
None, description=("Path to an overriding config.yml file")
|
||||||
|
|
@ -13,6 +15,9 @@ class ConfigOverrideConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class UtilityActions(BaseModel):
|
class UtilityActions(BaseModel):
|
||||||
|
"""Model used for arg actions."""
|
||||||
|
|
||||||
|
# YAML export options
|
||||||
export_config: Optional[str] = Field(
|
export_config: Optional[str] = Field(
|
||||||
None, description="generate a template config file"
|
None, description="generate a template config file"
|
||||||
)
|
)
|
||||||
|
|
@ -20,6 +25,7 @@ class UtilityActions(BaseModel):
|
||||||
"config_sample.yml", description="path to export configuration file to"
|
"config_sample.yml", description="path to export configuration file to"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# OpenAPI JSON export options
|
||||||
export_openapi: Optional[bool] = Field(
|
export_openapi: Optional[bool] = Field(
|
||||||
False, description="export openapi schema files"
|
False, description="export openapi schema files"
|
||||||
)
|
)
|
||||||
|
|
@ -29,6 +35,8 @@ class UtilityActions(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class NetworkConfig(BaseModel):
|
class NetworkConfig(BaseModel):
|
||||||
|
"""Model for network configuration."""
|
||||||
|
|
||||||
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
|
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
|
||||||
port: Optional[int] = Field(5000, description=("The port to host on"))
|
port: Optional[int] = Field(5000, description=("The port to host on"))
|
||||||
disable_auth: Optional[bool] = Field(
|
disable_auth: Optional[bool] = Field(
|
||||||
|
|
@ -47,6 +55,8 @@ class NetworkConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class LoggingConfig(BaseModel):
|
class LoggingConfig(BaseModel):
|
||||||
|
"""Model for logging configuration."""
|
||||||
|
|
||||||
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
|
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
|
||||||
log_generation_params: Optional[bool] = Field(
|
log_generation_params: Optional[bool] = Field(
|
||||||
False, description=("Enable generation parameter logging")
|
False, description=("Enable generation parameter logging")
|
||||||
|
|
@ -55,6 +65,8 @@ class LoggingConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
|
"""Model for LLM configuration."""
|
||||||
|
|
||||||
# TODO: convert this to a pathlib.path?
|
# TODO: convert this to a pathlib.path?
|
||||||
model_dir: str = Field(
|
model_dir: str = Field(
|
||||||
"models",
|
"models",
|
||||||
|
|
@ -201,6 +213,8 @@ class ModelConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class DraftModelConfig(BaseModel):
|
class DraftModelConfig(BaseModel):
|
||||||
|
"""Model for draft LLM model configuration."""
|
||||||
|
|
||||||
# TODO: convert this to a pathlib.path?
|
# TODO: convert this to a pathlib.path?
|
||||||
draft_model_dir: Optional[str] = Field(
|
draft_model_dir: Optional[str] = Field(
|
||||||
"models",
|
"models",
|
||||||
|
|
@ -239,6 +253,8 @@ class DraftModelConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class LoraInstanceModel(BaseModel):
|
class LoraInstanceModel(BaseModel):
|
||||||
|
"""Model representing an instance of a Lora."""
|
||||||
|
|
||||||
name: str = Field(..., description=("Name of the LoRA model"))
|
name: str = Field(..., description=("Name of the LoRA model"))
|
||||||
scaling: float = Field(
|
scaling: float = Field(
|
||||||
1.0,
|
1.0,
|
||||||
|
|
@ -248,6 +264,8 @@ class LoraInstanceModel(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class LoraConfig(BaseModel):
|
class LoraConfig(BaseModel):
|
||||||
|
"""Model for lora configuration."""
|
||||||
|
|
||||||
# TODO: convert this to a pathlib.path?
|
# TODO: convert this to a pathlib.path?
|
||||||
lora_dir: Optional[str] = Field(
|
lora_dir: Optional[str] = Field(
|
||||||
"loras", description=("Directory to look for LoRAs (default: 'loras')")
|
"loras", description=("Directory to look for LoRAs (default: 'loras')")
|
||||||
|
|
@ -262,12 +280,16 @@ class LoraConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class SamplingConfig(BaseModel):
|
class SamplingConfig(BaseModel):
|
||||||
|
"""Model for sampling (overrides) config."""
|
||||||
|
|
||||||
override_preset: Optional[str] = Field(
|
override_preset: Optional[str] = Field(
|
||||||
None, description=("Select a sampler override preset")
|
None, description=("Select a sampler override preset")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DeveloperConfig(BaseModel):
|
class DeveloperConfig(BaseModel):
|
||||||
|
"""Model for developer settings configuration."""
|
||||||
|
|
||||||
unsafe_launch: Optional[bool] = Field(
|
unsafe_launch: Optional[bool] = Field(
|
||||||
False, description=("Skip Exllamav2 version check")
|
False, description=("Skip Exllamav2 version check")
|
||||||
)
|
)
|
||||||
|
|
@ -290,6 +312,8 @@ class DeveloperConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsConfig(BaseModel):
|
class EmbeddingsConfig(BaseModel):
|
||||||
|
"""Model for embeddings configuration."""
|
||||||
|
|
||||||
# TODO: convert this to a pathlib.path?
|
# TODO: convert this to a pathlib.path?
|
||||||
embedding_model_dir: Optional[str] = Field(
|
embedding_model_dir: Optional[str] = Field(
|
||||||
"models",
|
"models",
|
||||||
|
|
@ -310,6 +334,8 @@ class EmbeddingsConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class TabbyConfigModel(BaseModel):
|
class TabbyConfigModel(BaseModel):
|
||||||
|
"""Base model for a TabbyConfig."""
|
||||||
|
|
||||||
config: ConfigOverrideConfig = Field(
|
config: ConfigOverrideConfig = Field(
|
||||||
default_factory=ConfigOverrideConfig.model_construct
|
default_factory=ConfigOverrideConfig.model_construct
|
||||||
)
|
)
|
||||||
|
|
@ -331,6 +357,8 @@ class TabbyConfigModel(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
def generate_config_file(filename="config_sample.yml", indentation=2):
|
def generate_config_file(filename="config_sample.yml", indentation=2):
|
||||||
|
"""Creates a config.yml file from Pydantic models."""
|
||||||
|
|
||||||
schema = TabbyConfigModel.model_json_schema()
|
schema = TabbyConfigModel.model_json_schema()
|
||||||
|
|
||||||
def dump_def(id: str, indent=2):
|
def dump_def(id: str, indent=2):
|
||||||
|
|
@ -356,6 +384,3 @@ def generate_config_file(filename="config_sample.yml", indentation=2):
|
||||||
|
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
f.write(yaml)
|
f.write(yaml)
|
||||||
|
|
||||||
|
|
||||||
# generate_config_file("test.yml")
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
"""Common utility functions"""
|
"""Common utility functions"""
|
||||||
|
|
||||||
|
from typing import get_args, get_origin
|
||||||
|
|
||||||
|
|
||||||
def unwrap(wrapped, default=None):
|
def unwrap(wrapped, default=None):
|
||||||
"""Unwrap function for Optionals."""
|
"""Unwrap function for Optionals."""
|
||||||
|
|
@ -43,3 +45,17 @@ def flat_map(input_list):
|
||||||
"""Flattens a list of lists into a single list."""
|
"""Flattens a list of lists into a single list."""
|
||||||
|
|
||||||
return [item for sublist in input_list for item in sublist]
|
return [item for sublist in input_list for item in sublist]
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_type(type_hint):
|
||||||
|
"""Checks if a type contains a list."""
|
||||||
|
|
||||||
|
if get_origin(type_hint) is list:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Recursively check for lists inside type arguments
|
||||||
|
type_args = get_args(type_hint)
|
||||||
|
if type_args:
|
||||||
|
return any(is_list_type(arg) for arg in type_args)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue