From 92af6567052737d166d4d40877c3b4a0de682c5e Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sun, 15 Sep 2024 17:50:37 +0100 Subject: [PATCH] improve config generation action --- common/actions.py | 2 +- common/config_models.py | 84 +++++++++++++++++++++++++---------------- 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/common/actions.py b/common/actions.py index 44fb236..ebf5539 100644 --- a/common/actions.py +++ b/common/actions.py @@ -20,7 +20,7 @@ def branch_to_actions() -> bool: ) elif config.actions.export_config: - generate_config_file(config.actions.config_export_path) + generate_config_file(filename=config.actions.config_export_path) else: # did not branch diff --git a/common/config_models.py b/common/config_models.py index bf2787a..23e7320 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -1,11 +1,25 @@ -from pydantic import AliasChoices, BaseModel, ConfigDict, Field +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, PrivateAttr from typing import List, Literal, Optional, Union from pathlib import Path +from pydantic_core import PydanticUndefined + CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] -class ConfigOverrideConfig(BaseModel): +class Metadata(BaseModel): + """metadata model for config options""" + + include_in_config: Optional[bool] = Field(True) + + +class BaseConfigModel(BaseModel): + """Base model for config models with added metadata""" + + _metadata: Metadata = PrivateAttr(Metadata()) + + +class ConfigOverrideConfig(BaseConfigModel): """Model for overriding a provided config file.""" # TODO: convert this to a pathlib.path? @@ -13,8 +27,10 @@ class ConfigOverrideConfig(BaseModel): None, description=("Path to an overriding config.yml file") ) + _metadata: Metadata = PrivateAttr(Metadata(include_in_config=False)) -class UtilityActions(BaseModel): + +class UtilityActions(BaseConfigModel): """Model used for arg actions.""" # YAML export options @@ -33,8 +49,10 @@ class UtilityActions(BaseModel): "openapi.json", description="path to export openapi schema to" ) + _metadata: Metadata = PrivateAttr(Metadata(include_in_config=False)) -class NetworkConfig(BaseModel): + +class NetworkConfig(BaseConfigModel): """Model for network configuration.""" host: Optional[str] = Field("127.0.0.1", description=("The IP to host on")) @@ -54,7 +72,7 @@ class NetworkConfig(BaseModel): # TODO: Migrate config.yml to have the log_ prefix # This is a breaking change. -class LoggingConfig(BaseModel): +class LoggingConfig(BaseConfigModel): """Model for logging configuration.""" log_prompt: Optional[bool] = Field( @@ -74,7 +92,7 @@ class LoggingConfig(BaseModel): ) -class ModelConfig(BaseModel): +class ModelConfig(BaseConfigModel): """Model for LLM configuration.""" # TODO: convert this to a pathlib.path? @@ -219,10 +237,11 @@ class ModelConfig(BaseModel): ), ) + _metadata: Metadata = PrivateAttr(Metadata()) model_config = ConfigDict(protected_namespaces=()) -class DraftModelConfig(BaseModel): +class DraftModelConfig(BaseConfigModel): """Model for draft LLM model configuration.""" # TODO: convert this to a pathlib.path? @@ -262,7 +281,7 @@ class DraftModelConfig(BaseModel): ) -class LoraInstanceModel(BaseModel): +class LoraInstanceModel(BaseConfigModel): """Model representing an instance of a Lora.""" name: str = Field(..., description=("Name of the LoRA model")) @@ -273,7 +292,7 @@ class LoraInstanceModel(BaseModel): ) -class LoraConfig(BaseModel): +class LoraConfig(BaseConfigModel): """Model for lora configuration.""" # TODO: convert this to a pathlib.path? @@ -289,7 +308,7 @@ class LoraConfig(BaseModel): ) -class SamplingConfig(BaseModel): +class SamplingConfig(BaseConfigModel): """Model for sampling (overrides) config.""" override_preset: Optional[str] = Field( @@ -297,7 +316,7 @@ class SamplingConfig(BaseModel): ) -class DeveloperConfig(BaseModel): +class DeveloperConfig(BaseConfigModel): """Model for developer settings configuration.""" unsafe_launch: Optional[bool] = Field( @@ -321,7 +340,7 @@ class DeveloperConfig(BaseModel): ) -class EmbeddingsConfig(BaseModel): +class EmbeddingsConfig(BaseConfigModel): """Model for embeddings configuration.""" # TODO: convert this to a pathlib.path? @@ -366,30 +385,29 @@ class TabbyConfigModel(BaseModel): model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) -def generate_config_file(filename="config_sample.yml", indentation=2): +def generate_config_file( + model: BaseConfigModel = None, + filename: str = "config_sample.yml", + indentation: int = 2, +) -> None: """Creates a config.yml file from Pydantic models.""" - schema = TabbyConfigModel.model_json_schema() - - def dump_def(id: str, indent=2): - yaml = "" - indent = " " * indentation * indent - id = id.split("/")[-1] - - section = schema["$defs"][id]["properties"] - for property in section.keys(): # get type - comment = section[property]["description"] - yaml += f"{indent}# {comment}\n" - - value = section[property].get("default", "") - yaml += f"{indent}{property}: {value}\n\n" - - return yaml + "\n" - + schema = model if model else TabbyConfigModel() yaml = "" - for section in schema["properties"].keys(): - yaml += f"{section}:\n" - yaml += dump_def(schema["properties"][section]["$ref"]) + + for field, field_data in schema.model_fields.items(): + subfield_model = field_data.default_factory() + if not subfield_model._metadata.include_in_config: + continue + + yaml += f"# {subfield_model.__doc__}\n" + yaml += f"{field}:\n" + for subfield, subfield_data in subfield_model.model_fields.items(): + value = subfield_data.default + value = value if value is not None else "" + value = value if value is not PydanticUndefined else "" + yaml += f"{' ' * indentation}# {subfield_data.description}\n" + yaml += f"{' ' * indentation}{subfield}: {value}\n" yaml += "\n" with open(filename, "w") as f: